Segment Tree
Building the Segment Tree
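Segment trees answer range queries, such as finding the minimum or maximum over a range of indexes of an array, in $O(\log n)$ time per query after an $O(n)$ build.
Let's consider the array $[-1, 3, 4, 0, 2, 3, 1, -3]$. The diagram below shows its segment tree for maximum queries: each node stores the result for its index range, and Index is the node's position in the segment tree array.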
graph TD
A0["Result: 4\nRange: 0-7\nIndex: 0\n "]:::round --- A1["Result: 4\nRange: 0-3\nIndex: 1\n "]:::round
A0 --- A2["Result: 3\nRange: 4-7\nIndex: 2\n "]:::round
%% Left subtree (0-3)
A1 --- B1["Result: 3\nRange: 0-1\nIndex: 3\n "]:::round
A1 --- B2["Result: 4\nRange: 2-3\nIndex: 4\n "]:::round
B1 --- C1["Result: -1\nRange: 0-0\nIndex: 7\n "]:::round
B1 --- C2["Result: 3\nRange: 1-1\nIndex: 8\n "]:::round
B2 --- C3["Result: 4\nRange: 2-2\nIndex: 9\n "]:::round
B2 --- C4["Result: 0\nRange: 3-3\nIndex: 10\n "]:::round
%% Right subtree (4-7)
A2 --- B3["Result: 3\nRange: 4-5\nIndex: 5\n "]:::round
A2 --- B4["Result: 1\nRange: 6-7\nIndex: 6\n "]:::round
B3 --- C5["Result: 2\nRange: 4-4\nIndex: 11\n "]:::round
B3 --- C6["Result: 3\nRange: 5-5\nIndex: 12\n "]:::round
B4 --- C7["Result: 1\nRange: 6-6\nIndex: 13\n "]:::round
B4 --- C8["Result: -3\nRange: 7-7\nIndex: 14\n "]:::round
%% Rounded node styling
classDef round rx:5, ry:5, stroke:#333, stroke-width:1px;
- As we can see, the segment tree consists of $2n-1$ nodes, where $n$ is the size of the original array (here $n = 8$, so $15$ nodes).
- The first node (the root) contains the result of the index range $(0$ to $7)$.
- Taking mid = (0+7) // 2 = 3:
  - The left child contains the result of the index range $(0$ to mid$)$.
  - The right child contains the result of the index range $($mid$+1$ to $7)$.
- The positions of the left and right child in the segment tree array are $(2i+1)$ and $(2i+2)$ respectively, where $i$ is the position of the parent.
- When $n$ is not a power of two, the tree still has $2n-1$ nodes, but this layout can leave unused positions, so the segment tree array may need more than $2n-1$ slots ($4n$ is always enough).
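The layout rules above can be sketched as a small helper that records which index range each array position covers. This is only an illustration: `node_ranges` and `visit` are hypothetical names that do not appear in the notebook, and the helper stores ranges rather than query results.

```python
from typing import Dict, Tuple


def node_ranges(n: int) -> Dict[int, Tuple[int, int]]:
    """Map each segment tree array position to the (start, end) range it covers."""
    ranges: Dict[int, Tuple[int, int]] = {}

    def visit(pos: int, start: int, end: int) -> None:
        ranges[pos] = (start, end)
        if start == end:  # leaf: covers a single element
            return
        mid = (start + end) // 2
        visit(2 * pos + 1, start, mid)    # left child
        visit(2 * pos + 2, mid + 1, end)  # right child

    visit(0, 0, n - 1)
    return ranges


layout = node_ranges(8)
print(len(layout))                       # 15 nodes for n = 8
print(layout[0], layout[1], layout[2])   # (0, 7) (0, 3) (4, 7)
print(max(node_ranges(6)))               # 12: larger than 2n - 2 = 10 when n = 6
```

For $n = 8$ the positions match the Index labels in the diagram. For $n = 6$ the tree has 11 nodes but uses position 12, which is why an array of exactly $2n-1$ slots is only safe when $n$ is a power of two.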
Querying the Segment Tree
Each node visited during a query falls into one of three cases:
- Complete overlap: The query range completely covers the range of the current node. In this case we return the value of the node.
- Partial overlap: The query range covers only part of the current node's range. In this case we recurse into both the left and right child and combine their results, continuing until we reach nodes whose ranges are completely covered by the query range.
- No overlap: The query range and the node's range are disjoint. In this case we return a value that cannot affect the result: a very large number (for a min query) or a very small number (for a max query).
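To make the three cases concrete, here is a minimal sketch that classifies a query range against a node's range. The function name `classify_overlap` and its parameter names are assumptions made for this illustration; all ranges are inclusive.

```python
def classify_overlap(q_start: int, q_end: int, node_left: int, node_right: int) -> str:
    """Classify how the query [q_start, q_end] relates to the node [node_left, node_right]."""
    if q_end < node_left or node_right < q_start:
        return "none"      # the two ranges are disjoint
    if q_start <= node_left and node_right <= q_end:
        return "complete"  # the query covers the whole node
    return "partial"       # the ranges intersect, but the node is not fully covered


# Query (1, 3) against some nodes of the tree for the example array:
print(classify_overlap(1, 3, 0, 7))  # partial
print(classify_overlap(1, 3, 2, 3))  # complete
print(classify_overlap(1, 3, 4, 7))  # none
```

Checking the disjoint case first keeps the remaining two conditions simple; the class below checks the same three cases in a different order.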
Defining the segment tree class
In [1]:
from typing import List


class SegmentTree:
    def __init__(self, data_: List[int], key_: str):
        if not data_:
            raise ValueError("data_ must not be empty")
        self.key = key_
        self.data = data_
        # Children live at 2i + 1 and 2i + 2, so the array must hold a complete
        # tree over the next power of two >= n (2n - 1 slots when n is a power of two).
        leaves = 1 << (len(data_) - 1).bit_length()
        self.tree = [0] * ((leaves << 1) - 1)
        self.func = None
        if key_ == 'max':
            self.func = max
        elif key_ == 'min':
            self.func = min
        else:
            raise ValueError("Invalid key function. Use 'max' or 'min'")

    def __builder(self, index_: int, start_: int, end_: int):
        # Leaf node: stores a single element of the original array.
        if start_ == end_:
            self.tree[index_] = self.data[start_]
            return
        # index is the left child's position; index + 1 is the right child's.
        mid, index = (start_ + end_) >> 1, (index_ << 1) + 1
        self.__builder(index, start_, mid)
        self.__builder(index + 1, mid + 1, end_)
        self.tree[index_] = self.func(self.tree[index], self.tree[index + 1])

    def build(self):
        self.__builder(0, 0, len(self.data) - 1)

    def __get_range_value(self, start_: int, end_: int, left_: int, right_: int, index_: int) -> int:
        # [start_, end_] is the query range; [left_, right_] is the current node's range.
        # total overlap
        if start_ <= left_ and end_ >= right_:
            return self.tree[index_]
        # partial overlap
        elif (start_ <= right_ and start_ >= left_) or (end_ <= right_ and end_ >= left_):
            mid, index = (left_ + right_) >> 1, (index_ << 1) + 1
            left_child = self.__get_range_value(start_, end_, left_, mid, index)
            right_child = self.__get_range_value(start_, end_, mid + 1, right_, index + 1)
            return self.func(left_child, right_child)
        # no overlap: return a value that can never win the max/min comparison
        else:
            return float('-inf') if self.key == 'max' else float('inf')

    def get_value_from_range(self, start_: int, end_: int) -> int:
        if start_ < 0 or end_ >= len(self.data) or start_ > end_:
            raise ValueError("Invalid range")
        return self.__get_range_value(start_, end_, 0, len(self.data) - 1, 0)
Driver code
In [2]:
data = [-1, 3, 4, 0, 2, 3, 1, -3]
segment_tree = SegmentTree(data, 'max')
segment_tree.build()
In [3]:
print("Initial Segment Tree:")
print(segment_tree.tree)
Initial Segment Tree:
[4, 4, 3, 3, 4, 3, 1, -1, 3, 4, 0, 2, 3, 1, -3]
In [4]:
queries = [(0, 0), (1, 3), (4, 7)]
for query in queries:
    l, r = query
    print(f"Max of range ({l}, {r}): {segment_tree.get_value_from_range(l, r)}")
Max of range (0, 0): -1
Max of range (1, 3): 4
Max of range (4, 7): 3
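As a sanity check, the same answers can be reproduced without a segment tree by scanning each slice directly. This is a minimal sketch; `brute_force_max` is a hypothetical helper, not part of the class above, and it costs $O(n)$ per query.

```python
from typing import List


def brute_force_max(values: List[int], l: int, r: int) -> int:
    """Return the maximum of values[l..r] (inclusive) by scanning the slice."""
    return max(values[l:r + 1])


data = [-1, 3, 4, 0, 2, 3, 1, -3]
queries = [(0, 0), (1, 3), (4, 7)]
expected = [brute_force_max(data, l, r) for l, r in queries]
print(expected)  # [-1, 4, 3]
```

Comparing a tree against a brute-force reference like this on random arrays and ranges is a simple way to catch off-by-one errors in the overlap conditions.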